import json
from time import time
from typing import Any, Dict, List, Optional, Sequence, Union
from uuid import uuid4

from pydantic import BaseModel, ConfigDict, Field

from agno.media import Audio, File, Image, Video
from agno.models.metrics import Metrics
from agno.utils.log import log_debug, log_error, log_info, log_warning


class MessageReferences(BaseModel):
    """References added to user message"""

    # Pydantic model pairing a retrieval query with the references it produced.
    # Only `query` is required; the remaining fields default to None.

    # The query used to retrieve the references.
    query: str
    # References (from the vector database or function calls).
    # Each entry is either a dictionary or a plain string.
    references: Optional[List[Union[Dict[str, Any], str]]] = None
    # Time taken to retrieve the references.
    # NOTE(review): the unit is not stated in this file (presumably seconds) - confirm with the producer.
    time: Optional[float] = None


class UrlCitation(BaseModel):
    """URL of the citation"""

    # Both fields are optional and default to None.

    # Address of the cited resource.
    url: Optional[str] = None
    # Title accompanying the cited URL.
    title: Optional[str] = None


class DocumentCitation(BaseModel):
    """Document of the citation"""

    # All fields are optional and default to None.

    # Title of the cited document.
    document_title: Optional[str] = None
    # The passage of text that was cited.
    cited_text: Optional[str] = None
    # Name of the file the citation refers to.
    file_name: Optional[str] = None


class Citations(BaseModel):
    """Citations for the message"""

    # Container grouping the citation data attached to a message.
    # Every field is optional and defaults to None.

    # Raw citations from the model, kept untyped (Any) so any provider-specific
    # shape can be stored as-is.
    raw: Optional[Any] = None

    # URLs of the citations.
    urls: Optional[List[UrlCitation]] = None

    # Document Citations
    documents: Optional[List[DocumentCitation]] = None


class Message(BaseModel):
    """Message sent to the Model"""

    # Unique identifier of the message; defaults to a freshly generated UUID4 string.
    id: str = Field(default_factory=lambda: str(uuid4()))

    # The role of the message author.
    # One of system, user, assistant, or tool.
    # NOTE: the type is a plain `str`, so the allowed values are not enforced here.
    role: str
    # The contents of the message: either a plain string or a list of content parts.
    content: Optional[Union[List[Any], str]] = None
    # An optional name for the participant.
    # Provides the model information to differentiate between participants of the same role.
    name: Optional[str] = None
    # Tool call that this message is responding to.
    tool_call_id: Optional[str] = None
    # The tool calls generated by the model, such as function calls.
    # Each entry is a dictionary; `log()` reads its "id" and "function" keys.
    tool_calls: Optional[List[Dict[str, Any]]] = None

    # Additional modalities (media attached to the message as input)
    audio: Optional[Sequence[Audio]] = None
    images: Optional[Sequence[Image]] = None
    videos: Optional[Sequence[Video]] = None
    files: Optional[Sequence[File]] = None

    # Output from the models (at most one object per modality)
    audio_output: Optional[Audio] = None
    image_output: Optional[Image] = None
    video_output: Optional[Video] = None
    file_output: Optional[File] = None

    # The thinking content from the model
    # NOTE(review): presumably the provider-redacted form of the reasoning - confirm.
    redacted_reasoning_content: Optional[str] = None

    # Data from the provider we might need on subsequent messages
    provider_data: Optional[Dict[str, Any]] = None

    # Citations received from the model
    citations: Optional[Citations] = None

    # --- Data not sent to the Model API ---
    # The reasoning content from the model
    reasoning_content: Optional[str] = None
    # The name of the tool called
    tool_name: Optional[str] = None
    # Arguments passed to the tool
    tool_args: Optional[Any] = None
    # The error of the tool call
    # (a boolean flag, not the error text)
    tool_call_error: Optional[bool] = None
    # If True, the agent will stop executing after this tool call.
    stop_after_tool_call: bool = False
    # When True, the message will be added to the agent's memory.
    add_to_agent_memory: bool = True
    # This flag is enabled when a message is fetched from the agent's memory.
    from_history: bool = False
    # Metrics for the message.
    # A new Metrics instance is created per message via default_factory.
    metrics: Metrics = Field(default_factory=Metrics)
    # The references added to the message for RAG
    references: Optional[MessageReferences] = None
    # The Unix timestamp the message was created.
    # Stored as whole seconds (the float from time() is truncated with int()).
    created_at: int = Field(default_factory=lambda: int(time()))

    # Pydantic configuration: unknown keyword arguments are kept on the instance
    # (extra="allow"), fields may be populated by name, and arbitrary
    # (non-pydantic) types are permitted as field types.
    model_config = ConfigDict(extra="allow", populate_by_name=True, arbitrary_types_allowed=True)

    def get_content_string(self) -> str:
        """Render the message content as a single string.

        Returns:
            str: The content itself when it is a string. For list content, the
                ``"text"`` value of the first element when that element is a dict
                holding a ``"text"`` key, otherwise the JSON encoding of the whole
                list. An empty string for anything else (including ``None``).
        """
        body = self.content
        if isinstance(body, str):
            return body
        if not isinstance(body, list):
            return ""

        # Only the leading part is inspected; an empty list falls through to JSON.
        leading_part = body[0] if len(body) > 0 else None
        if isinstance(leading_part, dict) and "text" in leading_part:
            return leading_part.get("text", "")
        return json.dumps(body)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Message":
        """Build a Message from a dictionary, reconstructing media entries first.

        The list fields ``images``, ``audio``, ``videos`` and ``files`` and the single
        fields ``audio_output``, ``image_output`` and ``video_output`` are rebuilt
        into their media classes when they are given as dictionaries:

        * a dict whose ``"content"`` is a string is passed to the media class's
          ``from_base64`` constructor together with its known metadata keys;
        * any other dict is expanded into the media class constructor;
        * values that are not dicts (e.g. already-built media objects) are kept.

        The caller's dictionary is not modified: reconstruction happens on a
        shallow copy.

        Args:
            data: Dictionary of message fields, e.g. the output of ``to_dict``.

        Returns:
            Message: The message built from ``data``.
        """
        # Shallow copy so rebuilt media objects are never written back into the caller's dict.
        data = dict(data)

        def _format_extras(item: Dict[str, Any]) -> Dict[str, Any]:
            # Metadata forwarded for images and videos.
            return {"format": item.get("format")}

        def _audio_extras(item: Dict[str, Any]) -> Dict[str, Any]:
            # Metadata forwarded for audio; sample rate and channel count fall back
            # to 24000 and 1 when absent.
            return {
                "transcript": item.get("transcript"),
                "expires_at": item.get("expires_at"),
                "sample_rate": item.get("sample_rate", 24000),
                "channels": item.get("channels", 1),
            }

        def _file_extras(item: Dict[str, Any]) -> Dict[str, Any]:
            # Metadata forwarded for files.
            return {
                "filename": item.get("filename"),
                "name": item.get("name"),
                "format": item.get("format"),
            }

        def _rebuild(media_cls: Any, extras: Any, item: Any) -> Any:
            # Turn one serialized entry back into a media object; non-dicts pass through.
            if not isinstance(item, dict):
                return item
            encoded = item.get("content")
            if isinstance(encoded, str):
                # String content is treated as base64 and handed to from_base64.
                return media_cls.from_base64(
                    encoded,
                    id=item.get("id"),
                    mime_type=item.get("mime_type"),
                    **extras(item),
                )
            # Regular media entry (e.g. filepath/url): expand the dict as-is.
            return media_cls(**item)

        # Fields holding a sequence of media entries.
        sequence_fields = (
            ("images", Image, _format_extras),
            ("audio", Audio, _audio_extras),
            ("videos", Video, _format_extras),
            ("files", File, _file_extras),
        )
        for key, media_cls, extras in sequence_fields:
            entries = data.get(key)
            if entries:
                data[key] = [_rebuild(media_cls, extras, entry) for entry in entries]

        # Fields holding a single media entry.
        # NOTE(review): `file_output` has no reconstruction step here and is passed
        # through to the constructor unchanged - confirm whether that is intended.
        single_fields = (
            ("audio_output", Audio, _audio_extras),
            ("image_output", Image, _format_extras),
            ("video_output", Video, _format_extras),
        )
        for key, media_cls, extras in single_fields:
            entry = data.get(key)
            if entry:
                data[key] = _rebuild(media_cls, extras, entry)

        return cls(**data)

    def to_dict(self) -> Dict[str, Any]:
        """Returns the message as a dictionary.

        Scalar fields that are ``None`` and list/dict fields that are empty are
        omitted. Media attachments and media outputs are converted with their own
        ``to_dict`` methods, ``references`` with ``model_dump`` and ``metrics`` with
        ``Metrics.to_dict`` (dropped again when that result is empty).
        ``created_at`` is always included.

        ``image_output`` and ``video_output`` are serialized alongside
        ``audio_output`` so that the output media which ``from_dict`` knows how to
        rebuild survive a ``to_dict`` / ``from_dict`` round trip.

        Returns:
            Dict[str, Any]: The serialized message.
        """
        message_dict = {
            "id": self.id,
            "content": self.content,
            "reasoning_content": self.reasoning_content,
            "from_history": self.from_history,
            "stop_after_tool_call": self.stop_after_tool_call,
            "role": self.role,
            "name": self.name,
            "tool_call_id": self.tool_call_id,
            "tool_name": self.tool_name,
            "tool_args": self.tool_args,
            "tool_call_error": self.tool_call_error,
            "tool_calls": self.tool_calls,
            "redacted_reasoning_content": self.redacted_reasoning_content,
            "provider_data": self.provider_data,
        }
        # Filter out None and empty collections (False and "" are kept on purpose)
        message_dict = {
            k: v for k, v in message_dict.items() if v is not None and not (isinstance(v, (list, dict)) and len(v) == 0)
        }

        # Convert media objects to dictionaries
        if self.images:
            message_dict["images"] = [img.to_dict() for img in self.images]
        if self.audio:
            message_dict["audio"] = [aud.to_dict() for aud in self.audio]
        if self.videos:
            message_dict["videos"] = [vid.to_dict() for vid in self.videos]
        if self.files:
            message_dict["files"] = [file.to_dict() for file in self.files]

        # Media produced by the model. Only the outputs that from_dict can rebuild
        # are emitted; file_output is left out because from_dict has no
        # reconstruction step for it.
        if self.audio_output:
            message_dict["audio_output"] = self.audio_output.to_dict()
        if self.image_output:
            message_dict["image_output"] = self.image_output.to_dict()
        if self.video_output:
            message_dict["video_output"] = self.video_output.to_dict()

        if self.references:
            message_dict["references"] = self.references.model_dump()
        if self.metrics:
            message_dict["metrics"] = self.metrics.to_dict()
            # Drop the key again when the metrics serialize to nothing
            if not message_dict["metrics"]:
                message_dict.pop("metrics")

        message_dict["created_at"] = self.created_at
        return message_dict

    def to_function_call_dict(self) -> Dict[str, Any]:
        """Return the tool-call related fields of the message as a dictionary.

        Values are copied as-is without filtering: ``None`` entries are kept and
        ``metrics`` is the Metrics object itself, not its dictionary form.

        Returns:
            Dict[str, Any]: Mapping with the keys ``content``, ``tool_call_id``,
                ``tool_name``, ``tool_args``, ``tool_call_error``, ``metrics`` and
                ``created_at``.
        """
        exported_fields = (
            "content",
            "tool_call_id",
            "tool_name",
            "tool_args",
            "tool_call_error",
            "metrics",
            "created_at",
        )
        return {field_name: getattr(self, field_name) for field_name in exported_fields}

    def log(self, metrics: bool = True, level: Optional[str] = None):
        """Log the message to the console

        Emits, in order: a header with the role, the name, the tool call id, the
        reasoning content, the content, the tool calls, counts of attached media
        and (optionally) the metrics. Sections without data are skipped.

        Malformed tool calls never abort logging: a missing or non-dict
        ``"function"`` entry is treated as empty, and arguments that cannot be
        parsed as JSON are reported as ``'Invalid JSON format'``.

        Args:
            metrics (bool): Whether to log the metrics.
            level (str): The level to log the message at. One of debug, info, warning, or error.
                Defaults to debug.
        """
        # Any unrecognised level (including None) falls back to debug.
        _logger = log_debug
        if level == "info":
            _logger = log_info
        elif level == "warning":
            _logger = log_warning
        elif level == "error":
            _logger = log_error

        try:
            import shutil

            terminal_width = shutil.get_terminal_size().columns
        except Exception:
            terminal_width = 80  # fallback width

        header = f" {self.role} "
        _logger(f"{header.center(terminal_width - 20, '=')}")

        if self.name:
            _logger(f"Name: {self.name}")
        if self.tool_call_id:
            _logger(f"Tool call Id: {self.tool_call_id}")
        if self.reasoning_content:
            _logger(f"<reasoning>\n{self.reasoning_content}\n</reasoning>")
        if self.content:
            if isinstance(self.content, (str, list)):
                _logger(self.content)
            elif isinstance(self.content, dict):
                _logger(json.dumps(self.content, indent=2))
        if self.tool_calls:
            tool_calls_list = ["Tool Calls:"]
            for tool_call in self.tool_calls:
                # "function" may be absent, None or malformed; treat all of those as empty
                # so that logging never fails on a bad tool call.
                function_spec = tool_call.get("function")
                if not isinstance(function_spec, dict):
                    function_spec = {}

                tool_id = tool_call.get("id")
                function_name = function_spec.get("name")
                if tool_id:
                    tool_calls_list.append(f"  - ID: '{tool_id}'")
                if function_name:
                    tool_calls_list.append(f"    Name: '{function_name}'")

                tool_call_arguments = function_spec.get("arguments")
                if not tool_call_arguments:
                    continue
                try:
                    tool_call_args = (
                        tool_call_arguments
                        if isinstance(tool_call_arguments, dict)
                        else json.loads(tool_call_arguments)
                    )
                except (json.JSONDecodeError, TypeError):
                    # TypeError covers arguments that are neither a dict nor str/bytes
                    tool_calls_list.append("    Arguments: 'Invalid JSON format'")
                    continue
                if not tool_call_args:
                    continue
                # Parsed JSON is not necessarily an object; only dicts are expanded key by key
                if isinstance(tool_call_args, dict):
                    arguments = ", ".join(f"{k}: {v}" for k, v in tool_call_args.items())
                    tool_calls_list.append(f"    Arguments: '{arguments}'")
                else:
                    tool_calls_list.append(f"    Arguments: '{tool_call_args}'")
            tool_calls_str = "\n".join(tool_calls_list)

            _logger(tool_calls_str)
        if self.images:
            _logger(f"Images added: {len(self.images)}")
        if self.videos:
            _logger(f"Videos added: {len(self.videos)}")
        if self.audio:
            _logger(f"Audio Files added: {len(self.audio)}")
        if self.files:
            _logger(f"Files added: {len(self.files)}")

        metrics_header = " TOOL METRICS " if self.role == "tool" else " METRICS "
        # Skip the section entirely when the metrics equal a default-constructed Metrics
        if metrics and self.metrics is not None and self.metrics != Metrics():
            _logger(metrics_header, center=True, symbol="*")

            # Token metrics: (label shown in the log, Metrics attribute), in display order
            token_fields = (
                ("input", "input_tokens"),
                ("output", "output_tokens"),
                ("total", "total_tokens"),
                ("cached", "cache_read_tokens"),
                ("cache_write_tokens", "cache_write_tokens"),
                ("reasoning", "reasoning_tokens"),
                ("audio", "audio_total_tokens"),
            )
            token_metrics = []
            for label, attribute in token_fields:
                count = getattr(self.metrics, attribute)
                if count and count > 0:
                    token_metrics.append(f"{label}={count}")
            if token_metrics:
                _logger(f"* Tokens:                      {', '.join(token_metrics)}")

            # Time related metrics
            if self.metrics.duration is not None and self.metrics.duration > 0:
                _logger(f"* Duration:                    {self.metrics.duration:.4f}s")
            if self.metrics.output_tokens and self.metrics.duration and self.metrics.duration > 0:
                _logger(
                    f"* Tokens per second:           {self.metrics.output_tokens / self.metrics.duration:.4f} tokens/s"
                )
            if self.metrics.time_to_first_token is not None and self.metrics.time_to_first_token > 0:
                _logger(f"* Time to first token:         {self.metrics.time_to_first_token:.4f}s")

            # Non-generic metrics
            if self.metrics.provider_metrics:
                _logger(f"* Provider metrics:            {self.metrics.provider_metrics}")
            if self.metrics.additional_metrics:
                _logger(f"* Additional metrics:          {self.metrics.additional_metrics}")

            _logger(metrics_header, center=True, symbol="*")

    def content_is_valid(self) -> bool:
        """Report whether the message carries any content.

        Returns:
            bool: ``False`` when the content is ``None`` or has length zero
                (empty string or empty list), ``True`` otherwise.
        """
        payload = self.content
        if payload is None:
            return False
        return len(payload) != 0
